import typing
from typing import Union

import torch
from torch import Tensor

import torch_geometric.typing
from torch_geometric import is_compiling
from torch_geometric.utils import is_sparse
from torch_geometric.typing import Size, SparseTensor
{% for module in modules %}
from {{module}} import *
{%- endfor %}


{% include "collect.jinja" %}


# Template for the generated `edge_updater` method. The parameter list and the
# return annotation are rendered from `signature`; the names forwarded to
# `edge_update` and exposed to hooks are rendered from `collect_param_dict`.
#
# Flow of the rendered function:
#   1. `self._check_input(edge_index, size)` produces `mutable_size`.
#   2. `self.<collect_name>(...)` gathers the arguments into `kwargs`, whose
#      fields are read by attribute access further down.
#   3. Unless scripting or compiling, the edge-update forward pre-hooks run
#      and may replace `kwargs`.
#   4. `self.edge_update(...)` is called with one keyword per collected name.
#   5. Unless scripting or compiling, the edge-update forward hooks run and
#      may replace the output.
#
# Args:
#   edge_index: `Tensor` or `SparseTensor`; passed to `_check_input`, to the
#       collect method and to every hook.
#   <signature parameters>: forwarded positionally to the collect method.
#   size: optional size; passed to `_check_input` and to every hook.
#
# Returns:
#   The result of `self.edge_update(...)`, or the most recent non-None value
#   returned by a forward hook.
def edge_updater(
    self,
    edge_index: Union[Tensor, SparseTensor],
{%- for param in signature.param_dict.values() %}
    {{param.name}}: {{param.type_repr}},
{%- endfor %}
    size: Size = None,
) -> {{signature.return_type_repr}}:

    # NOTE(review): `_check_input` is defined outside this template;
    # presumably it validates `edge_index` against `size` -- confirm.
    # Its result is what the collect call receives, while hooks below are
    # handed the caller's original `size`.
    mutable_size = self._check_input(edge_index, size)

    # Gather everything `edge_update` needs into a single `kwargs` object.
    kwargs = self.{{collect_name}}(
        edge_index,
{%- for name in signature.param_dict %}
        {{name}},
{%- endfor %}
        mutable_size,
    )

    # Begin Edge Update Forward Pre Hook #######################################
    # Hooks are skipped entirely while scripting or compiling.
    if not torch.jit.is_scripting() and not is_compiling():
        for hook in self._edge_update_forward_pre_hooks.values():
            # Hooks receive the collected arguments as a plain dict that is
            # rebuilt from the current `kwargs` for every hook.
            hook_kwargs = dict(
{%- for name in collect_param_dict %}
                {{name}}=kwargs.{{name}},
{%- endfor %}
            )
            res = hook(self, (edge_index, size, hook_kwargs))
            # A hook returns None (no change) or a replacement
            # `(edge_index, size, hook_kwargs)` triple.
            if res is not None:
                edge_index, size, hook_kwargs = res
                # Only `hook_kwargs` feeds back into `kwargs`. The replaced
                # `edge_index` and `size` are passed on to later hooks, but
                # the arguments are not collected again from them.
                # NOTE(review): `CollectArgs` is not defined in this template;
                # presumably it comes from the included "collect.jinja" --
                # confirm.
                kwargs = CollectArgs(
{%- for name in collect_param_dict %}
                    {{name}}=hook_kwargs['{{name}}'],
{%- endfor %}
                )
    # End Edge Update Forward Pre Hook #########################################

    # Run the actual edge update with one keyword argument per collected name.
    out = self.edge_update(
{%- for name in collect_param_dict %}
        {{name}}=kwargs.{{name}},
{%- endfor %}
    )

    # Begin Edge Update Forward Hook ###########################################
    # Hooks are skipped entirely while scripting or compiling.
    if not torch.jit.is_scripting() and not is_compiling():
        for hook in self._edge_update_forward_hooks.values():
            hook_kwargs = dict(
{%- for name in collect_param_dict %}
                {{name}}=kwargs.{{name}},
{%- endfor %}
            )
            res = hook(self, (edge_index, size, hook_kwargs), out)
            # A non-None return value replaces the output, so each later hook
            # sees the output produced by the hooks before it.
            out = res if res is not None else out
    # End Edge Update Forward Hook #############################################

    return out
